{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import time\n",
    "import numpy as np\n",
    "import os\n",
    "import threading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(path):\n",
    "    if path == 'network':\n",
    "        mnist = tf.keras.datasets.mnist\n",
    "        return mnist.load_data()\n",
    "    else:\n",
    "        # path = 'file\\\\mnist.npz'\n",
    "        f = np.load(path)\n",
    "        x_train, y_train = f['x_train'], f['y_train']\n",
    "        x_test, y_test = f['x_test'], f['y_test']\n",
    "        f.close()\n",
    "        return (x_train, y_train), (x_test, y_test)\n",
    "    \n",
    "# Build the archive path with os.path.join so it is not tied to the Windows separator.\n",
    "(x_train, y_train), (x_test, y_test) = load_data(os.path.join('file', 'mnist.npz'))\n",
    "# Divide both image arrays by 255.0 (presumably rescaling 8-bit pixel values to [0, 1]).\n",
    "x_train, x_test = x_train / 255.0, x_test / 255.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  wf>>>>>>>>>>>>>>>>>>>>> PID:62480  ident:56376\n",
      "WARNING:tensorflow:From <ipython-input-3-84e24f46f6d8>:43: Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use Model.fit, which supports generators.\n",
      "  generator initiated\n",
      "WARNING:tensorflow:sample_weight modes were coerced from\n",
      "  ...\n",
      "    to  \n",
      "  ['...']\n",
      "Train for 20 steps\n",
      "  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 0  PID:62480  ident:47528\n",
      " 1/20 [>.............................] - ETA: 15s - loss: 2.4204 - accuracy: 0.0625  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1  PID:62480  ident:70604\n",
      " 3/20 [===>..........................] - ETA: 38s - loss: 2.1975 - accuracy: 0.2188  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 2  PID:62480  ident:7060\n",
      " 4/20 [=====>........................] - ETA: 39s - loss: 2.1036 - accuracy: 0.3125\n",
      "  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 4  PID:62480  ident:5817\n",
      " 6/20 [========>.....................] - ETA: 36s - loss: 1.9333 - accuracy: 0.4740  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 5  PID:62480  ident:5817\n",
      "  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 6  PID:62480  ident:70604\n",
      " 8/20 [===========>..................] - ETA: 32s - loss: 1.7945 - accuracy: 0.5508  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 7  PID:62480  ident:4752\n",
      "  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 8  PID:62480  ident:7060\n",
      "10/20 [==============>...............] - ETA: 27s - loss: 1.6676 - accuracy: 0.6187  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 9  PID:62480  ident:4752\n",
      "  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 10  PID:62480  ident:47528\n",
      "11/20 [===============>..............] - ETA: 25s - loss: 1.6076 - accuracy: 0.6477  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 11  PID:62480  ident:69016\n",
      "12/20 [=================>............] - ETA: 22s - loss: 1.5485 - accuracy: 0.6771  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 12  PID:62480  ident:69016\n",
      "13/20 [==================>...........] - ETA: 19s - loss: 1.4914 - accuracy: 0.6971  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 13  PID:62480  ident:69016\n",
      "14/20 [====================>.........] - ETA: 17s - loss: 1.4382 - accuracy: 0.7188  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 14  PID:62480  ident:69016\n",
      "15/20 [=====================>........] - ETA: 14s - loss: 1.3912 - accuracy: 0.7354  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 15  PID:62480  ident:70604\n",
      "16/20 [=======================>......] - ETA: 11s - loss: 1.3419 - accuracy: 0.7520  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 16  PID:62480  ident:70604\n",
      "18/20 [==========================>...] - ETA: 5s - loss: 1.2507 - accuracy: 0.7795   wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 17  PID:62480  ident:690\n",
      "19/20 [===========================>..] - ETA: 2s - loss: 1.2064 - accuracy: 0.7911\n",
      "  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 19  PID:62480  ident:475\n",
      "20/20 [==============================] - 58s 3s/step - loss: 1.1659 - accuracy: 0.8016\n",
      "Total used time: 61 \n"
     ]
    }
   ],
   "source": [
    "def gen():\n",
    "    \"\"\"Yield the first 32 training samples forever, pausing 3 s after each batch.\n",
    "\n",
    "    Reads the module-level x_train / y_train arrays. Every time the generator\n",
    "    is resumed after a yield it logs the batch counter together with the\n",
    "    current process id and thread ident, then sleeps for 3 seconds before\n",
    "    producing the next batch.\n",
    "\n",
    "    Yields:\n",
    "        The tuple (x_train[:32], y_train[:32]) on every iteration.\n",
    "    \"\"\"\n",
    "    print('  generator initiated')\n",
    "    idx = 0\n",
    "    while True:\n",
    "        yield x_train[:32], y_train[:32]\n",
    "        # current_thread() replaces the deprecated currentThread() alias.\n",
    "        thread_id = threading.current_thread().ident\n",
    "        print('  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch %d  PID:%d  ident:%d' % (idx, os.getpid(), thread_id))\n",
    "        idx += 1\n",
    "        time.sleep(3)\n",
    " \n",
    "if __name__ == \"__main__\":\n",
    "    tr_gen = gen()\n",
    "\n",
    "    # Flatten -> Dense(128, relu) -> Dropout(0.2) -> Dense(10, softmax).\n",
    "    model = tf.keras.models.Sequential([\n",
    "        tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
    "        tf.keras.layers.Dense(128, activation='relu'),\n",
    "        tf.keras.layers.Dropout(0.2),\n",
    "        tf.keras.layers.Dense(10, activation='softmax'),\n",
    "    ])\n",
    "\n",
    "    # Log the process id and thread ident of the thread that launches training.\n",
    "    print('  wf>>>>>>>>>>>>>>>>>>>>> PID:%d  ident:%d' % (os.getpid(), threading.current_thread().ident))\n",
    "    model.compile(optimizer='adam',\n",
    "                  loss='sparse_categorical_crossentropy',\n",
    "                  metrics=['accuracy'])\n",
    "\n",
    "    start_time = time.time()\n",
    "    # Model.fit accepts generators directly; Model.fit_generator is deprecated.\n",
    "    model.fit(tr_gen, steps_per_epoch=20, max_queue_size=10)\n",
    "    print('Total used time: %d '% (time.time() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  wf>>>>>>>>>>>>>>>>>>>>> PID:56548  ident:70232\n",
      "WARNING:tensorflow:From <ipython-input-1-cf0d04c31338>:45: Model.fit_generator (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use Model.fit, which supports generators.\n",
      "  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 0  PID:56548  ident:70232\n",
      "WARNING:tensorflow:sample_weight modes were coerced from\n",
      "  ...\n",
      "    to  \n",
      "  ['...']\n",
      "Train for 20 steps\n",
      "  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1740  PID:56548  ident:56836\n",
      "  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1682  PID:56548  ident:56836\n",
      " 1/20 [>.............................] - ETA: 1:06 - loss: 2.2855 - accuracy: 0.0625  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1435  PID:56548  ident:568\n",
      " 3/20 [===>..........................] - ETA: 53s - loss: 2.2717 - accuracy: 0.1042   wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1622  PID:56548  ident:1112\n",
      "  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1803  PID:56548  ident:63468\n",
      " 5/20 [======>.......................] - ETA: 46s - loss: 2.2200 - accuracy: 0.1562  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 844  PID:56548  ident:5683\n",
      "  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1116  PID:56548  ident:56836\n",
      " 6/20 [========>.....................] - ETA: 43s - loss: 2.1808 - accuracy: 0.2031  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1820  PID:56548  ident:11124\n",
      " 7/20 [=========>....................] - ETA: 39s - loss: 2.1536 - accuracy: 0.2277  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 576  PID:56548  ident:5683\n",
      " 9/20 [============>.................] - ETA: 33s - loss: 2.1190 - accuracy: 0.2535  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1464  PID:56548  ident:6634\n",
      "  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1794  PID:56548  ident:56836\n",
      "10/20 [==============>...............] - ETA: 30s - loss: 2.0852 - accuracy: 0.2781  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1407  PID:56548  ident:63468\n",
      "12/20 [=================>............] - ETA: 24s - loss: 2.0335 - accuracy: 0.3151  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1593  PID:56548  ident:6346\n",
      "  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1663  PID:56548  ident:56836\n",
      "13/20 [==================>...........] - ETA: 21s - loss: 1.9986 - accuracy: 0.3365  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 11  PID:56548  ident:1112\n",
      "14/20 [====================>.........] - ETA: 18s - loss: 1.9731 - accuracy: 0.3527  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 833  PID:56548  ident:56836\n",
      "16/20 [=======================>......] - ETA: 12s - loss: 1.9013 - accuracy: 0.3965  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1258  PID:56548  ident:6346\n",
      "  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 995  PID:56548  ident:11124\n",
      "17/20 [========================>.....] - ETA: 9s - loss: 1.8685 - accuracy: 0.4118   wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 1634  PID:56548  ident:66344\n",
      "18/20 [==========================>...] - ETA: 6s - loss: 1.8450 - accuracy: 0.4271  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 519  PID:56548  ident:663\n",
      "20/20 [==============================] - 61s 3s/step - loss: 1.7845 - accuracy: 0.4578>>>>>>>>>>>>>>>>>>>>>generator yielded a batch 89  PID:56548  ident:634\n",
      "\n",
      "Total used time: 66 \n"
     ]
    }
   ],
   "source": [
    "class MNISTSequence(tf.keras.utils.Sequence):\n",
    "    \"\"\"Keras Sequence that serves (x, y) mini-batches with a 3 s delay per batch.\n",
    "\n",
    "    Each __getitem__ call logs the requested batch index together with the\n",
    "    current process id and thread ident before sleeping.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, x_set, y_set, batch_size):\n",
    "        \"\"\"Store the sample array, the label array and the batch size.\"\"\"\n",
    "        # Let the Keras base class initialise itself as well.\n",
    "        super().__init__()\n",
    "        self.x, self.y = x_set, y_set\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of batches; the last one may be shorter.\"\"\"\n",
    "        return int(np.ceil(len(self.x) / float(self.batch_size)))\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return batch number idx as a pair of freshly built numpy arrays.\"\"\"\n",
    "        start = idx * self.batch_size\n",
    "        stop = start + self.batch_size\n",
    "        batch_x = self.x[start:stop]\n",
    "        batch_y = self.y[start:stop]\n",
    "        # current_thread() replaces the deprecated currentThread() alias.\n",
    "        thread_id = threading.current_thread().ident\n",
    "        print('  wf>>>>>>>>>>>>>>>>>>>>>generator yielded a batch %d  PID:%d  ident:%d' % (idx, os.getpid(), thread_id))\n",
    "        time.sleep(3)\n",
    "        return np.array(batch_x), np.array(batch_y)\n",
    "    \n",
    "if __name__ == \"__main__\":\n",
    "    tr_gen = MNISTSequence(x_train, y_train, 32)\n",
    "\n",
    "    # Flatten -> Dense(128, relu) -> Dropout(0.2) -> Dense(10, softmax).\n",
    "    model = tf.keras.models.Sequential([\n",
    "        tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
    "        tf.keras.layers.Dense(128, activation='relu'),\n",
    "        tf.keras.layers.Dropout(0.2),\n",
    "        tf.keras.layers.Dense(10, activation='softmax'),\n",
    "    ])\n",
    "    # Log the process id and thread ident of the thread that launches training.\n",
    "    print('  wf>>>>>>>>>>>>>>>>>>>>> PID:%d  ident:%d' % (os.getpid(), threading.current_thread().ident))\n",
    "    model.compile(optimizer='adam',\n",
    "                  loss='sparse_categorical_crossentropy',\n",
    "                  metrics=['accuracy'])\n",
    "    start_time = time.time()\n",
    "    # Model.fit accepts Sequence objects directly; Model.fit_generator is deprecated.\n",
    "    # workers=0 is kept as in the original run (see the note in the following markdown cell).\n",
    "    model.fit(tr_gen, steps_per_epoch=20, max_queue_size=10, use_multiprocessing=True, workers=0)\n",
    "    print('Total used time: %d '% (time.time() - start_time))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在我自己的 PC 上，`workers` 只能设为 0，否则会报错；但在我使用的 GPU 服务器上测试时，`workers` 不为 0 也可以正常运行，并能大幅缩短训练时间。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
